{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Understanding the PyTorch parallel scan, with RNNs/Mamba in mind"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "TODOs :\n",
    "- sections\n",
    "- proof, in a separate file"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This file gives a detailed background and explanation of the `pscan.py` file. The goal is to implement a parallel scan in PyTorch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import math"
   ]
  },
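  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First of all, what is a <b>scan</b> ?\n",
    "\n",
    "A scan takes an array as input and produces an array of the same length as output : each output element combines all the input elements up to its position with a binary operator (a sum, a product, a maximum...). It is therefore quite a general operation.\n",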
    "\n",
    " A simple and well-known example of a scan is the <i>cumulative sum of an array</i> :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1,  3,  6, 10])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.tensor([1, 2, 3, 4])\n",
    "\n",
    "torch.cumsum(X, dim=0)"
   ]
  },
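  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cumulative sum is only one example. As a small illustrative sketch (it is not taken from `pscan.py`), the same pattern works with other operators : `torch.cumprod` computes a cumulative product and `torch.cummax` a running maximum. The second input tensor below is a made-up example.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "X = torch.tensor([1, 2, 3, 4])\n",
    "\n",
    "# cumulative product : each output multiplies all the inputs up to its position\n",
    "cumulative_product = torch.cumprod(X, dim=0)\n",
    "\n",
    "# running maximum : torch.cummax returns a (values, indices) pair\n",
    "running_max, _ = torch.cummax(torch.tensor([1, 3, 2, 4]), dim=0)\n",
    "\n",
    "print(cumulative_product.tolist())  # [1, 2, 6, 24]\n",
    "print(running_max.tolist())         # [1, 3, 3, 4]\n",
    "```"
   ]
  },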
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For the rest of this document, we will denote by `L` the length of our input array `X`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The most basic way to implement a scan is to use a simple for loop :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1,  3,  6, 10])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = torch.zeros_like(X)\n",
    "\n",
    "cumulative_sum = 0\n",
    "for t in range(X.size(0)):\n",
    "    cumulative_sum += X[t]\n",
    "    Y[t] = cumulative_sum\n",
    "\n",
    "Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Quite simple for now, right ?\n",
    "To set up our notation, we will keep this example for a bit.\n",
    "\n",
    "Here, we use an accumulator, `cumulative_sum`, which we update as we go through the input array `X`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An equivalent way to rewrite the above code is :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1,  3,  6, 10])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = torch.zeros_like(X)\n",
    "\n",
    "Y[0] = X[0]\n",
    "for t in range(1, X.size(0)):\n",
    "    Y[t] = Y[t-1] + X[t]\n",
    "\n",
    "Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While not explicitly present, the accumulator is still here, in `Y`. It is propagated with the recurrence relation `Y[t] = Y[t-1] + X[t]`.\n",
    "\n",
    "We can visualize what happens with a simple diagram, which should remind you a bit of RNNs :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/cumsum_rnns.jpg\" alt=\"cumulative sum\" width=\"1000\" height=\"300\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In some sense, `Y` plays the role of the hidden state, while `X` plays the role of the input : as we process the input, we keep and update a running hidden state."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that this method of computation, which uses a sequential loop, requires `L` sequential steps to compute the whole output `Y`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, is it possible to <b>parallelize</b> this scan operation ?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is exactly what the <b>parallel scan</b> does.\n",
    "\n",
    "Let's stay with the simple example of our cumulative sum. In fact, let's simplify it even more : let's say we just want to compute the sum of our input array `X`.\n",
    "Again, we could come up with a for loop to add up the elements. But can't we parallelize this computation ?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yes, and it is best visualized with this simple tree :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/reduction_tree.jpg\" alt=\"cumulative sum\" width=\"1000\" height=\"600\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we assume that the length of our array is a power of 2, then we just have to group the elements 2 by 2, add them, and repeat until we are left with one element, our result. If `L` is `2**d`, then we will need to do `math.log2(L)=d` sequential steps to compute the sum of the array. That's a major speedup over the `L` steps of the naive for loop."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How could we implement this in Python ?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # input array\n",
    "L = X.size(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can <i>group the elements by two</i> using :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1, 2],\n",
       "        [3, 4],\n",
       "        [5, 6],\n",
       "        [7, 8]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa = X.view(L//2, 2)\n",
    "Xa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have pairs of elements. We can do :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([1, 3, 5, 7])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa[:, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([2, 4, 6, 8])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "to access the elements from the two groups. We can see that, to compute the first step, we simply need to sum these two arrays :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 3,  7, 11, 15])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa[:, 0] + Xa[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Yay ! We have just accomplished our first step ! Now, we just need to repeat what we've just done.\n",
    "\n",
    "Note that we could do `Xa = Xa[:, 0] + Xa[:, 1]` and then we would just need to repeat the previous step with `Xa`. But :\n",
    "- this would allocate extra memory to store the result of the first step (`Xa` currently shares its data with `X`, whereas this would compute the sum and store the output in newly allocated memory).\n",
    "- we will reuse some of these values, later, for the full scan operation.\n",
    "\n",
    "Hence, we will work by updating `X` <b>in-place</b>, by doing :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 3,  7, 11, 15])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa[:, 1] += Xa[:, 0]\n",
    "Xa[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is what we want. Now, we will repeat the step we have just done, but on `Xa[:, 1]` rather than on `X`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "Xa = Xa[:, 1]\n",
    "Xa = Xa.view(Xa.size(0)//2, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Again, `Xa` is split in two groups :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 3,  7],\n",
       "        [11, 15]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We sum these two groups, and put the result in the second half of `Xa`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([10, 26])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa[:, 1] += Xa[:, 0]\n",
    "Xa[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have two elements left ! This means, only one more step to go (because `math.log2(2) = 1` of course !) :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([36])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Xa = Xa[:, 1]\n",
    "Xa = Xa.view(Xa.size(0)//2, 2)\n",
    "\n",
    "Xa[:, 1] += Xa[:, 0]\n",
    "Xa[:, 1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Which is the result we want !"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To recap, here is the full code :"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 1,  3,  3, 10,  5, 11,  7, 36])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # input array\n",
    "L = X.size(0)\n",
    "\n",
    "Xa = X\n",
    "\n",
    "for k in range(int(math.log2(L))):\n",
    "    T = 2 * (Xa.size(0) // 2)\n",
    "\n",
    "    # split into 2 groups of pairs of elements\n",
    "    Xa = Xa.view(T//2, 2) \n",
    "\n",
    "    # for each pair, add the first to the second\n",
    "    Xa[:, 1].add_(Xa[:, 0])\n",
    "\n",
    "    # change the view for the next iteration\n",
    "    Xa = Xa[:, 1]\n",
    "\n",
    "X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our result is found in the last element of `X` (remember that all the `Xa`'s are views of our original tensor!)\n",
    "\n",
    "Let's look at how `X` evolved during this whole process : each line corresponds to the data of tensor `X` at a particular iteration."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/tensor_mem.jpeg\" alt=\"tensor evolution in memory\" width=\"1000\" height=\"500\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To help clarify, we can also draw on top of this the tree underlying the operation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/tensor_mem_tree.jpeg\" alt=\"tensor evolution in memory\" width=\"1000\" height=\"500\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that having all these numbers is a bit cumbersome, as only a few of them change at each iteration.\n",
    "\n",
    "So, for the rest of this document, we'll represent this with the following diagram :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/tree_reduction.jpeg\" alt=\"tensor evolution in memory\" width=\"1000\" height=\"500\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can see that this is the exact same tree that was drawn in red in the above figure.\n",
    "\n",
    "It is essentially a simple tree again, **but** :\n",
    "- the placement of the nodes is important : see how the sum of each pair is placed on the right element of the pair.\n",
    "- at each level of the tree, we can easily see the exact state of the array : for each position, just look at the topmost element in the column."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Ok, so we can compute the sum of an array in parallel. Now, what about the cumulative sum of an array ?\n",
    "\n",
    "If we look again at the previous tree, we can see that the values of some nodes are partial sums of our input array. For example, the 10 is the sum of elements 1 to 4.\n",
    "\n",
    "In fact, it's quite easy to see (and demonstrate) that each node holds the sum of the elements of `X` which are its children or more generally its descendants.\n",
    "\n",
    "We can see it even more clearly here :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/tree_reduction_xs.jpeg\" alt=\"tensor evolution in memory\" width=\"1000\" height=\"550\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Considering this, it becomes clearer that we will be able to reuse this tree to construct the cumulative sum of `X` : it is filled with partial sums of the input array. That's why we used in-place operations in the code above : once the sum has been computed, some values of this tree are left in `X`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we look again at this tree, we see that, for a node that is <b>a left child</b> and whose parent is a right child, adding the value of the parent's left sibling (the \"left brother\" of its parent, if you will) to the node extends its partial sum. It's best visualized here, where adding $x_{1->4}$ to $x_{5->6}$ gives us $x_{1->6}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/downsweep_ex.jpg\" alt=\"downsweep, 1 step\" width=\"1000\" height=\"550\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This was step 1. Let's say we do this computation from top to bottom, for all nodes that are <b>left children</b> :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/downsweep.jpg\" alt=\"downsweep, 1 step\" width=\"1000\" height=\"550\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is good : some values are already cumulative sums ($x_{1->3}$ and $x_{1->7}$), some are partial sums ($x_{3->5}$) and some are unchanged.\n",
    "\n",
    "But, and this is where it gets a bit tricky, if we account for the fact that this isn't a static tree but a representation of the evolution of our tensor in memory, we know that we execute this sequence of computations (called the <b>down-sweep</b>, because we go from top to bottom) <b>after</b> the up-sweep (where we went from bottom to top). Hence, if we look at the fourth value of our tensor for example, during the down-sweep, it is already equal to $x_{1->4}$ (remember that we only do in-place operations). The same goes for the last value : after the <b>up-sweep</b>, it isn't equal to $x_8$, but to $x_{1->8}$."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/downsweep_updated.jpg\" alt=\"downsweep, 1 step\" width=\"1000\" height=\"550\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is not yet a proof (the proof is left to the reader 😉), but essentially, we can see that the descendants of the two nodes we're summing would combine and we would get the sum of all the descendants up to the <b>left child</b> we are considering.\n",
    "\n",
    "For a full and rigorous proof, we would need to slightly modify the formulation of our algorithm (while keeping it strictly equivalent) : we consider two trees, one built by the up-sweep operation, and the other built by the down-sweep operation, as follows :\n",
    "\n",
    "- start by placing at the root the top element of the up-sweep tree (i.e., the sum)\n",
    "- then, give its right child its own value, and its left child its own value minus the value of the corresponding right child in the up-sweep tree\n",
    "- repeat, until the bottom level has as many nodes as the input array has elements.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is what we get for the example :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/up_down_trees.jpg\" alt=\"downsweep, 1 step\" width=\"1600\" height=\"550\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the rule we followed to construct the down-sweep tree :"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<p align=\"center\">\n",
    "    <img src=\"assets/down_sweep_rule.jpg\" alt=\"downsweep, 1 step\" width=\"550\" height=\"550\"/>\n",
    "</p>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, how did we compute the value $28$ during the last step of the down-sweep ? We subtracted from the parent, $36$, the value of the right child of the corresponding node in the up-sweep tree : $8$. This gives us $36-8 = 28$."
   ]
  },
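  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two-tree construction described above can be sketched as follows. This is only an illustration under stated assumptions : the helper names `up_sweep_levels` and `down_sweep_levels` are hypothetical, each tree is stored as a list of levels, and the input length is assumed to be a power of 2.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def up_sweep_levels(x):\n",
    "    # levels[0] holds the leaves, levels[-1] holds the root (the total sum)\n",
    "    levels = [x.clone()]\n",
    "    while levels[-1].numel() > 1:\n",
    "        pairs = levels[-1].view(-1, 2)\n",
    "        levels.append(pairs[:, 0] + pairs[:, 1])\n",
    "    return levels\n",
    "\n",
    "def down_sweep_levels(up):\n",
    "    # down[0] holds the root, down[-1] holds the leaves\n",
    "    down = [up[-1].clone()]\n",
    "    for level in reversed(up[:-1]):\n",
    "        parents = down[-1]\n",
    "        # right children of the corresponding nodes in the up-sweep tree\n",
    "        up_right = level.view(-1, 2)[:, 1]\n",
    "        # left child = parent - up-sweep right child, right child = parent\n",
    "        children = torch.stack([parents - up_right, parents], dim=1)\n",
    "        down.append(children.reshape(-1))\n",
    "    return down\n",
    "\n",
    "up = up_sweep_levels(torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]))\n",
    "down = down_sweep_levels(up)\n",
    "print(down[-1].tolist())  # [1, 3, 6, 10, 15, 21, 28, 36]\n",
    "```\n",
    "\n",
    "The bottom level of the down-sweep tree is the cumulative sum of the input, and its second-to-last value is the $28 = 36 - 8$ computed above."
   ]
  },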
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We could then show, by induction, that : <i>After a complete down-sweep, each node of the (down-sweep) tree contains the sum of all the leaf values that <b>precede it</b></i> (not necessarily descendants)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "proof"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is the same approach followed by Guy E. Blelloch in <i>Prefix Sums and Their Applications</i>. However, as you can see, I chose to follow a more algorithm-friendly approach and keep this math-friendly approach for the proof. Note also that Blelloch presents an exclusive scan (each output sums only the elements strictly before it), while an inclusive scan is presented here (each output also includes the current element)."
   ]
  },
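  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference between an inclusive and an exclusive scan concrete, here is a minimal sketch based on `torch.cumsum`. The variable names `inclusive` and `exclusive` are only illustrative.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "X = torch.tensor([1, 2, 3, 4])\n",
    "\n",
    "# inclusive scan : output t sums X[0], ..., X[t]\n",
    "inclusive = torch.cumsum(X, dim=0)\n",
    "\n",
    "# exclusive scan : output t sums X[0], ..., X[t-1], so the first output is 0\n",
    "exclusive = inclusive - X\n",
    "\n",
    "print(inclusive.tolist())  # [1, 3, 6, 10]\n",
    "print(exclusive.tolist())  # [0, 1, 3, 6]\n",
    "```"
   ]
  },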
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How does this translate into Python ?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After the up-sweep, what do we have ?\n",
    "We have :\n",
    "- `X`, our input array, in which some values have been replaced by partial sums.\n",
    "- `Xa`, which points to the rightmost value of `X`.\n",
    "\n",
    "At each step, we're going to \"grow back\" `Xa` to what it was before it was shrunk during the up-sweep. So at each step, it will double in size, just like its size was halved at each step of the up-sweep.\n",
    "It's actually a bit tricky to get `Xa` to \"grow back\" to what it was during the up-sweep. For the up-sweep, we just had to split the inputs into 2 groups, and do `Xa = Xa[:, 1]` to select half of the inputs. But here, we actually need to make `Xa` bigger, not smaller. So we have to tell it which indices to look for in `X`, and these are :\n",
    "\n",
    "`Xa = X[:, :, 2**k-1:L:2**k]`\n",
    "\n",
    "where `k` is the down-sweep step, which goes from `math.log2(L) - 1` down to `0`.\n",
    "I won't go into the details here : I think that writing it out on a scratchpad helps a lot more to see why this is the case than wordy explanations.\n",
    "\n",
    "For the sake of clarity, if we apply this formula with `L=8`, we have :\n",
    "- for the first step of the down-sweep, "
   ]
  },
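  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the up-sweep and the down-sweep together, here is a minimal sketch of the full in-place cumulative sum for a 1-D tensor whose length is a power of 2. It is an illustration under stated assumptions rather than the code of `pscan.py` : the function name `inclusive_scan_` is hypothetical, and the slice `X[2**k-1:L:2**k]` is the 1-D counterpart of the formula above, without its two leading dimensions.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import torch\n",
    "\n",
    "def inclusive_scan_(X):\n",
    "    # in-place cumulative sum of a 1-D tensor (length assumed to be a power of 2)\n",
    "    L = X.size(0)\n",
    "    assert L > 0 and (L & (L - 1)) == 0, 'length must be a power of 2'\n",
    "    num_steps = int(math.log2(L))\n",
    "\n",
    "    # up-sweep : for each pair, add the first element to the second one\n",
    "    Xa = X\n",
    "    for _ in range(num_steps):\n",
    "        Xa = Xa.view(Xa.size(0) // 2, 2)\n",
    "        Xa[:, 1].add_(Xa[:, 0])\n",
    "        Xa = Xa[:, 1]\n",
    "\n",
    "    # down-sweep : grow the view back, one level at a time\n",
    "    for k in range(num_steps - 1, -1, -1):\n",
    "        Xa = X[2**k - 1 : L : 2**k]\n",
    "        Xa = Xa.view(Xa.size(0) // 2, 2)\n",
    "        # each first element of a pair (except in the first pair)\n",
    "        # receives the second element of the previous pair\n",
    "        Xa[1:, 0].add_(Xa[:-1, 1])\n",
    "\n",
    "    return X\n",
    "\n",
    "X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])\n",
    "print(inclusive_scan_(X).tolist())  # [1, 3, 6, 10, 15, 21, 28, 36]\n",
    "```\n",
    "\n",
    "With `L=8`, the first down-sweep step (`k=2`) selects a single pair and therefore changes nothing. The steps `k=1` and `k=0` fill in the remaining cumulative sums, and the result matches `torch.cumsum`."
   ]
  },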
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "left to write :\n",
    "- python cum sum\n",
    "- introduction of A\n",
    "- code pscan.py forward\n",
    "\n",
    "- backward ?\n",
    "- code pscan.py backward"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch23",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
